{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\nHow to Use TVM Pass Instrument\n==============================\n**Author**: `Chi-Wei Wang <https://github.com/chiwwang>`_\n\nAs more and more passes are implemented, it becomes useful to instrument\npass execution, analyze per-pass effects, and observe various events.\n\nWe can instrument passes by providing a list of :py:class:`tvm.ir.instrument.PassInstrument`\ninstances to :py:class:`tvm.transform.PassContext`. We provide a pass instrument\nfor collecting timing information (:py:class:`tvm.ir.instrument.PassTimingInstrument`),\nand custom instruments can be defined with the :py:func:`tvm.instrument.pass_instrument` decorator.\n\nThis tutorial demonstrates how developers can use ``PassContext`` to instrument\npasses. Please also refer to the `pass-infra`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\nimport tvm.relay as relay\nfrom tvm.relay.testing import resnet\nfrom tvm.contrib.download import download_testdata\nfrom tvm.relay.build_module import bind_params_by_name\nfrom tvm.ir.instrument import (\n    PassTimingInstrument,\n    pass_instrument,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Create An Example Relay Program\n-------------------------------\nWe use the pre-defined ResNet-18 network in Relay.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "batch_size = 1\nnum_of_image_class = 1000\nimage_shape = (3, 224, 224)\noutput_shape = (batch_size, num_of_image_class)\n# Reuse batch_size instead of repeating the literal.\nrelay_mod, relay_params = resnet.get_workload(\n    num_layers=18, batch_size=batch_size, image_shape=image_shape\n)\nprint(\"Printing the IR module...\")\nprint(relay_mod.astext(show_meta_data=False))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Create PassContext With Instruments\n-----------------------------------\nTo run all passes with an instrument, pass it via the ``instruments`` argument to\nthe ``PassContext`` constructor. The built-in ``PassTimingInstrument`` is used to\nprofile the execution time of each pass.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "timing_inst = PassTimingInstrument()\nwith tvm.transform.PassContext(instruments=[timing_inst]):\n    relay_mod = relay.transform.InferType()(relay_mod)\n    relay_mod = relay.transform.FoldScaleAxis()(relay_mod)\n    # before exiting the context, get profile results.\n    profiles = timing_inst.render()\nprint(\"Printing results of timing profile...\")\nprint(profiles)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Use Current PassContext With Instruments\n----------------------------------------\nOne can also use the current ``PassContext`` and register\n``PassInstrument`` instances with the ``override_instruments`` method.\nNote that ``override_instruments`` first calls the ``exit_pass_ctx`` method\nof any instruments that are already registered. It then switches to the new instruments\nand calls their ``enter_pass_ctx`` method.\nRefer to the following sections and :py:func:`tvm.instrument.pass_instrument` for these methods.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cur_pass_ctx = tvm.transform.PassContext.current()\ncur_pass_ctx.override_instruments([timing_inst])\nrelay_mod = relay.transform.InferType()(relay_mod)\nrelay_mod = relay.transform.FoldScaleAxis()(relay_mod)\nprofiles = timing_inst.render()\nprint(\"Printing results of timing profile...\")\nprint(profiles)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Register an empty list to clear existing instruments.\n\nNote that ``exit_pass_ctx`` of ``PassTimingInstrument`` is called.\nProfiles are cleared, so nothing is printed.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cur_pass_ctx.override_instruments([])\n# Uncomment the call to .render() to see a warning like:\n# Warning: no passes have been profiled, did you enable pass profiling?\n# profiles = timing_inst.render()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Create Customized Instrument Class\n----------------------------------\nA customized instrument class can be created using the\n:py:func:`tvm.instrument.pass_instrument` decorator.\n\nLet's create an instrument class which calculates the change in number of\noccurrences of each operator caused by each pass. We can look at ``op.name`` to\nfind the name of each operator. We count operators before and after each pass\nand take the difference.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@pass_instrument\nclass RelayCallNodeDiffer:\n    def __init__(self):\n        self._op_diff = []\n        # Passes can be nested.\n        # Use stack to make sure we get correct before/after pairs.\n        self._op_cnt_before_stack = []\n\n    def enter_pass_ctx(self):\n        self._op_diff = []\n        self._op_cnt_before_stack = []\n\n    def exit_pass_ctx(self):\n        assert len(self._op_cnt_before_stack) == 0, \"The stack is not empty. Something wrong.\"\n\n    def run_before_pass(self, mod, info):\n        self._op_cnt_before_stack.append((info.name, self._count_nodes(mod)))\n\n    def run_after_pass(self, mod, info):\n        # Pop out the latest recorded pass.\n        name_before, op_to_cnt_before = self._op_cnt_before_stack.pop()\n        assert name_before == info.name, \"name_before: {}, info.name: {} doesn't match\".format(\n            name_before, info.name\n        )\n        cur_depth = len(self._op_cnt_before_stack)\n        op_to_cnt_after = self._count_nodes(mod)\n        op_diff = self._diff(op_to_cnt_after, op_to_cnt_before)\n        # only record passes causing differences.\n        if op_diff:\n            self._op_diff.append((cur_depth, info.name, op_diff))\n\n    def get_pass_to_op_diff(self):\n        \"\"\"\n        return [\n          (depth, pass_name, {op_name: diff_num, ...}), ...\n        ]\n        \"\"\"\n        return self._op_diff\n\n    @staticmethod\n    def _count_nodes(mod):\n        \"\"\"Count the number of occurrences of each operator in the module\"\"\"\n        ret = {}\n\n        def visit(node):\n            if isinstance(node, relay.expr.Call):\n                if hasattr(node.op, \"name\"):\n                    op_name = node.op.name\n                else:\n                    # Some CallNode may not have 'name' such as relay.Function\n                    return\n                ret[op_name] = ret.get(op_name, 0) + 1\n\n        relay.analysis.post_order_visit(mod[\"main\"], visit)\n        return ret\n\n    @staticmethod\n    def _diff(d_after, d_before):\n        \"\"\"Calculate the difference of two dictionaries along their keys.\n        The result is values in d_after minus values in d_before.\n        \"\"\"\n        ret = {}\n        key_after, key_before = set(d_after), set(d_before)\n        for k in key_before & key_after:\n            tmp = d_after[k] - d_before[k]\n            # Only record operators whose count actually changed.\n            if tmp:\n                ret[k] = tmp\n        for k in key_after - key_before:\n            ret[k] = d_after[k]\n        for k in key_before - key_after:\n            ret[k] = -d_before[k]\n        return ret"
      ]
    },
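    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The bookkeeping done by ``_diff`` can be checked without TVM. The cell below is a minimal\nstandalone sketch: ``diff_counts`` is a hypothetical helper that mirrors ``_diff`` for positive\ncounts, and the operator counts are made-up numbers for illustration, not the output of a real pass.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def diff_counts(d_after, d_before):\n    \"\"\"Return {op_name: after - before}, omitting operators whose count is unchanged.\"\"\"\n    ret = {}\n    for k in set(d_after) | set(d_before):\n        delta = d_after.get(k, 0) - d_before.get(k, 0)\n        if delta:\n            ret[k] = delta\n    return ret\n\n\n# Hypothetical operator counts before and after some pass (illustrative numbers only).\nbefore = {\"nn.conv2d\": 20, \"nn.batch_norm\": 19, \"add\": 8}\nafter = {\"nn.conv2d\": 20, \"add\": 27, \"multiply\": 19}\nop_delta = diff_counts(after, before)\n# Unchanged operators (nn.conv2d) are dropped; removed ones get a negative entry.\nprint(sorted(op_delta.items()))"
      ]
    },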
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Apply Passes and Multiple Instrument Classes\n--------------------------------------------\nWe can use multiple instrument classes in a ``PassContext``.\nHowever, instrument methods are executed sequentially,\nin the order given by the ``instruments`` argument.\nSo for instrument classes like ``PassTimingInstrument``, the final profile result\ninevitably includes the execution time of other instrument classes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "call_node_inst = RelayCallNodeDiffer()\ndesired_layouts = {\n    \"nn.conv2d\": [\"NHWC\", \"HWIO\"],\n}\npass_seq = tvm.transform.Sequential(\n    [\n        relay.transform.FoldConstant(),\n        relay.transform.ConvertLayout(desired_layouts),\n        relay.transform.FoldConstant(),\n    ]\n)\nrelay_mod[\"main\"] = bind_params_by_name(relay_mod[\"main\"], relay_params)\n# timing_inst is put after call_node_inst.\n# So the execution time of ``call_node_inst.run_after_pass()`` is also counted.\nwith tvm.transform.PassContext(opt_level=3, instruments=[call_node_inst, timing_inst]):\n    relay_mod = pass_seq(relay_mod)\n    profiles = timing_inst.render()\n# Uncomment the next line to see timing-profile results.\n# print(profiles)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can see how the number of ``CallNode`` instances increases or decreases per op type.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pprint import pprint\n\nprint(\"Printing the change in number of occurrences of each operator caused by each pass...\")\npprint(call_node_inst.get_pass_to_op_diff())"
      ]
    },
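    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The note on instrument order earlier in this section can be illustrated with a toy model that\ndoes not use TVM. This is only a sketch under one assumption taken from the text: hooks fire in\nthe order of the ``instruments`` list. ``FakeClock``, ``ToyTimer``, ``ToyCounter`` and\n``timed_ticks`` are hypothetical names, and the tick values are arbitrary.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class FakeClock:\n    \"\"\"Deterministic stand-in for a wall clock.\"\"\"\n\n    def __init__(self):\n        self.now = 0\n\n    def advance(self, ticks):\n        self.now += ticks\n\n\nclass ToyTimer:\n    \"\"\"Records the ticks elapsed between run_before_pass and run_after_pass.\"\"\"\n\n    def __init__(self, clock):\n        self.clock = clock\n        self.elapsed = None\n        self._start = None\n\n    def run_before_pass(self):\n        self._start = self.clock.now\n\n    def run_after_pass(self):\n        self.elapsed = self.clock.now - self._start\n\n\nclass ToyCounter:\n    \"\"\"Pretends to spend 5 ticks analysing the module after each pass.\"\"\"\n\n    def __init__(self, clock):\n        self.clock = clock\n\n    def run_before_pass(self):\n        pass\n\n    def run_after_pass(self):\n        self.clock.advance(5)\n\n\ndef timed_ticks(timer_last):\n    \"\"\"Run one 10-tick toy pass and return what the timer measured.\"\"\"\n    clock = FakeClock()\n    timer, counter = ToyTimer(clock), ToyCounter(clock)\n    instruments = [counter, timer] if timer_last else [timer, counter]\n    for inst in instruments:  # assumption: hooks fire in list order\n        inst.run_before_pass()\n    clock.advance(10)  # the toy pass itself takes 10 ticks\n    for inst in instruments:\n        inst.run_after_pass()\n    return timer.elapsed\n\n\n# With the timer last, the counter's run_after_pass is included in the measurement.\nprint(\"timer last :\", timed_ticks(timer_last=True))\nprint(\"timer first:\", timed_ticks(timer_last=False))"
      ]
    },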
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Exception Handling\n------------------\nLet's see what happens if an exception occurs in a method of a ``PassInstrument``.\n\nDefine ``PassInstrument`` classes which raise exceptions when entering or exiting a ``PassContext``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class PassExampleBase:\n    def __init__(self, name):\n        self._name = name\n\n    def enter_pass_ctx(self):\n        print(self._name, \"enter_pass_ctx\")\n\n    def exit_pass_ctx(self):\n        print(self._name, \"exit_pass_ctx\")\n\n    def should_run(self, mod, info):\n        print(self._name, \"should_run\")\n        return True\n\n    def run_before_pass(self, mod, pass_info):\n        print(self._name, \"run_before_pass\")\n\n    def run_after_pass(self, mod, pass_info):\n        print(self._name, \"run_after_pass\")\n\n\n@pass_instrument\nclass PassFine(PassExampleBase):\n    pass\n\n\n@pass_instrument\nclass PassBadEnterCtx(PassExampleBase):\n    def enter_pass_ctx(self):\n        print(self._name, \"bad enter_pass_ctx!!!\")\n        raise ValueError(\"{} bad enter_pass_ctx\".format(self._name))\n\n\n@pass_instrument\nclass PassBadExitCtx(PassExampleBase):\n    def exit_pass_ctx(self):\n        print(self._name, \"bad exit_pass_ctx!!!\")\n        raise ValueError(\"{} bad exit_pass_ctx\".format(self._name))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If an exception occurs in ``enter_pass_ctx``, ``PassContext`` will disable the pass\ninstrumentation. It will then run the ``exit_pass_ctx`` of each ``PassInstrument``\nwhich successfully finished ``enter_pass_ctx``.\n\nIn the following example, we can see that ``exit_pass_ctx`` of ``PassFine_0`` is executed after the exception.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "demo_ctx = tvm.transform.PassContext(\n    instruments=[\n        PassFine(\"PassFine_0\"),\n        PassBadEnterCtx(\"PassBadEnterCtx\"),\n        PassFine(\"PassFine_1\"),\n    ]\n)\ntry:\n    with demo_ctx:\n        relay_mod = relay.transform.InferType()(relay_mod)\nexcept ValueError as ex:\n    print(\"Catching\", str(ex).split(\"\\n\")[-1])"
      ]
    },
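    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The behavior described above can be modeled in a few lines of plain Python. This is a toy\nsketch of the described semantics, not TVM's actual implementation; ``ToyInstrument`` and\n``toy_enter`` are hypothetical names.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class ToyInstrument:\n    \"\"\"Stand-in for a PassInstrument that can fail in enter_pass_ctx.\"\"\"\n\n    def __init__(self, name, fail_on_enter=False):\n        self.name = name\n        self.fail_on_enter = fail_on_enter\n\n    def enter_pass_ctx(self):\n        if self.fail_on_enter:\n            raise ValueError(\"{} bad enter_pass_ctx\".format(self.name))\n\n    def exit_pass_ctx(self):\n        pass\n\n\ndef toy_enter(instruments, log):\n    \"\"\"Enter instruments in order; on failure, exit those already entered and re-raise.\"\"\"\n    entered = []\n    for inst in instruments:\n        try:\n            inst.enter_pass_ctx()\n        except Exception:\n            for done in entered:\n                done.exit_pass_ctx()\n                log.append(done.name + \" exit_pass_ctx\")\n            instruments.clear()  # models instrumentation being disabled\n            raise\n        entered.append(inst)\n        log.append(inst.name + \" enter_pass_ctx\")\n\n\nevents = []\ntoy_instruments = [\n    ToyInstrument(\"PassFine_0\"),\n    ToyInstrument(\"PassBadEnterCtx\", fail_on_enter=True),\n    ToyInstrument(\"PassFine_1\"),\n]\ntry:\n    toy_enter(toy_instruments, events)\nexcept ValueError as ex:\n    events.append(\"caught: \" + str(ex))\n# PassFine_1 never enters, and only PassFine_0 is exited.\nprint(events)"
      ]
    },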
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Exceptions in ``PassInstrument`` instances cause all instruments of the current ``PassContext``\nto be cleared, so nothing is printed when ``override_instruments`` is called.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "demo_ctx.override_instruments([])  # no PassFine_0 exit_pass_ctx printed....etc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If an exception occurs in ``exit_pass_ctx``, pass instrumentation is disabled\nand the exception is propagated. That means ``PassInstrument`` instances registered\nafter the one throwing the exception do not execute ``exit_pass_ctx``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "demo_ctx = tvm.transform.PassContext(\n    instruments=[\n        PassFine(\"PassFine_0\"),\n        PassBadExitCtx(\"PassBadExitCtx\"),\n        PassFine(\"PassFine_1\"),\n    ]\n)\ntry:\n    # PassFine_1 executes enter_pass_ctx, but not exit_pass_ctx.\n    with demo_ctx:\n        relay_mod = relay.transform.InferType()(relay_mod)\nexcept ValueError as ex:\n    print(\"Catching\", str(ex).split(\"\\n\")[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Exceptions that occur in ``should_run``, ``run_before_pass``, or ``run_after_pass``\nare not handled explicitly. We rely on the context manager (the ``with`` syntax)\nto exit ``PassContext`` safely.\n\nWe use ``run_before_pass`` as an example:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "@pass_instrument\nclass PassBadRunBefore(PassExampleBase):\n    def run_before_pass(self, mod, pass_info):\n        print(self._name, \"bad run_before_pass!!!\")\n        raise ValueError(\"{} bad run_before_pass\".format(self._name))\n\n\ndemo_ctx = tvm.transform.PassContext(\n    instruments=[\n        PassFine(\"PassFine_0\"),\n        PassBadRunBefore(\"PassBadRunBefore\"),\n        PassFine(\"PassFine_1\"),\n    ]\n)\ntry:\n    # All exit_pass_ctx are called.\n    with demo_ctx:\n        relay_mod = relay.transform.InferType()(relay_mod)\nexcept ValueError as ex:\n    print(\"Catching\", str(ex).split(\"\\n\")[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Also note that pass instrumentation is not disabled. So if we call\n``override_instruments``, the ``exit_pass_ctx`` of each previously registered\n``PassInstrument`` is called.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "demo_ctx.override_instruments([])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If we don't wrap pass execution with the ``with`` syntax, ``exit_pass_ctx`` is not\ncalled. Let's try this with the current ``PassContext``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cur_pass_ctx = tvm.transform.PassContext.current()\ncur_pass_ctx.override_instruments(\n    [\n        PassFine(\"PassFine_0\"),\n        PassBadRunBefore(\"PassBadRunBefore\"),\n        PassFine(\"PassFine_1\"),\n    ]\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then call passes. As expected, ``exit_pass_ctx`` is not executed after the exception.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "try:\n    # No ``exit_pass_ctx`` got executed.\n    relay_mod = relay.transform.InferType()(relay_mod)\nexcept ValueError as ex:\n    print(\"Catching\", str(ex).split(\"\\n\")[-1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Clear instruments.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cur_pass_ctx.override_instruments([])"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}